// Copyright 2024 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

// Generate KVM ARM64 register IDs for dev_kvm.txt
// Usage:
//
//	go run registers.go msr_mrs.txt
package main

import (
	"bytes"
	"fmt"
	"os"
	"regexp"
	"strconv"
	"strings"

	"github.com/google/syzkaller/pkg/tool"
)

// main expects exactly one command-line argument, the path to msr_mrs.txt, and
// writes the generated KVM ARM64 register descriptions to stdout: the system
// register IDs derived from that file, then a fixed list of extra IDs, framed
// by begin/end comment lines.
func main() {
	if len(os.Args) != 2 {
		tool.Failf("usage: gen msr_mrs.txt")
	}
	path := os.Args[1]
	table, err := os.ReadFile(path)
	if err != nil {
		tool.Failf("failed to open input file: %v", err)
	}

	fmt.Print("# Register descriptions generated by tools/arm64/registers.go\n")
	printSysRegIDs(table)
	printCoreRegs()
	fmt.Print("# End of register descriptions generated by tools/arm64/registers.go\n")
}

// printSysRegIDs converts an msr_mrs.txt-style table into register IDs and
// prints them to stdout as a single `kvm_regs_arm64_sys = 0x..., 0x...` line.
//
// Blank lines and lines whose first non-space character is "#" are skipped.
// Every other line is expanded by expandLine (one line per n[...] wildcard
// combination), and each expansion is turned into an ID by processLine.
// Lines that fail to convert are reported on stderr, not stdout, so the
// diagnostics do not end up inside the generated description.
func printSysRegIDs(table []byte) {
	var ids []string
	for _, raw := range bytes.Split(table, []byte("\n")) {
		line := strings.TrimSpace(string(raw))
		// Checking for "#" after trimming also skips indented comment lines.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		for _, eline := range expandLine(line) {
			value, err := processLine(eline)
			if err != nil {
				fmt.Fprintf(os.Stderr, "%v\n", err)
				continue
			}
			ids = append(ids, fmt.Sprintf("0x%x", value))
		}
	}
	// Join once instead of growing a string with += for every ID.
	fmt.Printf("kvm_regs_arm64_sys = %s\n", strings.Join(ids, ", "))
}

// Process a single (already wildcard-expanded) line of the following form:
//
//	`0b10    0b000   0b0000  0b0010  0b000   MDCCINT_EL1 ...`
//
// or
//
//	`0b00    0b000   0b0100  -       0b101   SPSel ...`
//
// The first five whitespace-separated fields are Op0, Op1, CRn, CRm and Op2,
// written in binary with an optional "0b" prefix; "-" stands for zero. At
// least one further field (the register name) must be present. The operands
// are passed to arm64KVMRegID to build the register ID.
//
// An error is returned if the line has fewer than six fields, if an operand is
// not valid binary, or if an operand does not fit in its field of the ID. An
// oversized operand would otherwise silently overwrite neighbouring fields.
func processLine(line string) (int64, error) {
	fields := strings.Fields(line)
	if len(fields) < 6 {
		return 0, fmt.Errorf("line has too few fields: %s", line)
	}

	// Field widths of Op0, Op1, CRn, CRm and Op2. They follow from the shifts
	// in arm64KVMRegID (14, 11, 7, 3, 0) with the coproc field starting at bit 16.
	widths := [5]uint{2, 3, 4, 4, 3}
	operands := make([]int, 0, len(widths))
	for i, width := range widths {
		if fields[i] == "-" {
			operands = append(operands, 0)
			continue
		}
		val, err := strconv.ParseInt(strings.TrimPrefix(fields[i], "0b"), 2, 64)
		if err != nil {
			return 0, fmt.Errorf("conversion error: %w", err)
		}
		if val < 0 || val >= 1<<width {
			return 0, fmt.Errorf("operand %d (%s) does not fit in %d bits: %s", i, fields[i], width, line)
		}
		operands = append(operands, int(val))
	}
	return arm64KVMRegID(operands), nil
}

// nRangeRe matches a bit-range wildcard such as `n[3]` or `n[1:0]`. Group 1 is
// an optional leading ":", present when the wildcard follows a literal binary
// prefix such as `0b0:n[1:0]`. Group 2 is the high bit index and group 4 the
// optional low bit index. It is compiled once instead of on every call.
var nRangeRe = regexp.MustCompile(`(:)?n\[(\d+)(:(\d+))?\]`)

// If a line contains bit wildcards, replace them with all possible bit permutations.
//
// E.g. the following line:
//
//	`0b11    0b100   0b1100  0b1000  0b0:n[1:0]      ICH_AP0R<n>_EL2 ...`
//
// will be expanded to:
//
//	`0b11    0b100   0b1100  0b1000  0b000      ICH_AP0R<n>_EL2 ...`
//	`0b11    0b100   0b1100  0b1000  0b001      ICH_AP0R<n>_EL2 ...`
//	`0b11    0b100   0b1100  0b1000  0b010      ICH_AP0R<n>_EL2 ...`
//	`0b11    0b100   0b1100  0b1000  0b011      ICH_AP0R<n>_EL2 ...`
//
// A wildcard n[hi:lo] (or n[b], meaning n[b:b]) is replaced by each
// (hi-lo+1)-bit binary string in ascending order. Remaining wildcards on the
// line are expanded recursively, so wildcards of widths w1, w2, ... yield
// 2^(w1+w2+...) lines. A malformed range is left unexpanded and the line is
// returned as-is instead of panicking. Malformed means: low index above the
// high index, unparsable, or wider than maxWildcardBits. If such a wildcard
// sits in an operand field, processLine will then report the line as
// unparsable.
func expandLine(line string) []string {
	// No operand field is wider than 4 bits. This cap only stops absurd ranges
	// from overflowing or exploding the 1<<width below.
	const maxWildcardBits = 16

	match := nRangeRe.FindStringSubmatch(line)
	if match == nil {
		return []string{line}
	}

	prefix := "0b"
	// If n[] is preceded by ":", there is a 0b prefix in front of it already.
	if match[1] == ":" {
		prefix = ""
	}
	hi, errHi := strconv.Atoi(match[2])
	lo, errLo := hi, error(nil)
	if match[4] != "" {
		lo, errLo = strconv.Atoi(match[4])
	}
	width := hi - lo + 1
	if errHi != nil || errLo != nil || width < 1 || width > maxWildcardBits {
		return []string{line}
	}

	count := 1 << width
	expanded := make([]string, 0, count)
	for i := 0; i < count; i++ {
		bits := fmt.Sprintf("%s%0*b", prefix, width, i)
		// Recurse to expand any further wildcards. A line without one comes back unchanged.
		expanded = append(expanded, expandLine(strings.Replace(line, match[0], bits, 1))...)
	}
	return expanded
}

// Building blocks of a KVM ARM64 system register ID, mirroring the definitions in
// https://elixir.bootlin.com/linux/v6.10.2/source/arch/arm64/include/uapi/asm/kvm.h
const (
	// kvmRegArm64 mirrors KVM_REG_ARM64 (architecture bits of the ID).
	kvmRegArm64 int64 = 0x6000000000000000

	// kvmRegSizeU64 mirrors KVM_REG_SIZE_U64 (register size bits of the ID).
	kvmRegSizeU64 int64 = 0x0030000000000000

	// kvmRegArmCoprocShift mirrors KVM_REG_ARM_COPROC_SHIFT.
	kvmRegArmCoprocShift = 16

	// kvmRegArm64Sysreg mirrors KVM_REG_ARM64_SYSREG: class 0x0013 placed at the coproc shift.
	kvmRegArm64Sysreg int64 = 0x0013 << kvmRegArmCoprocShift
)

// Generate register ID from Op0, Op1, CRn, CRm, Op2.
// See https://elixir.bootlin.com/linux/v6.10.2/source/arch/arm64/include/uapi/asm/kvm.h#L257 for more details.
func arm64KVMRegID(operands []int) int64 {
	shifts := [5]int64{14, 11, 7, 3, 0}
	ret := kvmRegSizeU64 | kvmRegArm64 | kvmRegArm64Sysreg
	for i := 0; i < 5; i++ {
		ret |= (int64(operands[i]) << shifts[i])
	}
	return ret
}

// printCoreRegs prints a hard-coded list of additional register IDs as the
// `kvm_regs_arm64_extra` description line, preceded by a comment line.
// According to the emitted comment, these are IDs that KVM_GET_REG_LIST
// reports on QEMU; they are not derived from msr_mrs.txt.
// See https://docs.kernel.org/virt/kvm/api.html for more details.
func printCoreRegs() {
	// Some of these register IDs do not have corresponding registers, yet the kernel returns them.
	// TODO(glider): figure out why this is happening.
	extraIDs := []string{
		"0x603000000013c01b", "0x603000000013c01f", "0x603000000013c022", "0x603000000013c023",
		"0x603000000013c025", "0x603000000013c026", "0x603000000013c027", "0x603000000013c02a",
		"0x603000000013c02b", "0x603000000013c02e", "0x603000000013c02f", "0x603000000013c033",
		"0x603000000013c034", "0x603000000013c035", "0x603000000013c036", "0x603000000013c037",
		"0x603000000013c03b", "0x603000000013c03c", "0x603000000013c03d", "0x603000000013c03e",
		"0x603000000013c03f", "0x603000000013c103", "0x603000000013c512", "0x603000000013c513",
	}
	fmt.Print("# Extra registers that KVM_GET_REG_LIST prints on QEMU\n")
	fmt.Print("kvm_regs_arm64_extra = " + strings.Join(extraIDs, ", ") + "\n")
}
